{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# The `SolveAiyagari` notebook <a id=\"Aiyagari\"></a>[<font size=1>(back to `Main.ipynb`)</font>](./Main.ipynb)\n",
    "\n",
    "This notebook gathers all functions related to the resolution of the Aiyagari model (fiscal policy is exogenous)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## This notebook is organized as follows:\n",
    "* Computing [policy functions](#pol-fun),\n",
    "* Computing the [stationary distribution](#stat-dist),\n",
    "* Characterizing the [steady-state equilibrium](#steady-state)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Computing policy functions  <a id=\"pol-fun\"></a>[<font size=1>(back to menu)</font>](#Aiyagari)\n",
    "\n",
    "The endogenous grid method (EGM) consists of iterating on policy functions until convergence. Compared to standard value function iteration (VFI) or policy function iteration (PFI), EGM characterizes choices as functions of end-of-period (instead of beginning-of-period) savings. This simplifies the iteration over policy functions and affects how the stationary distribution is computed. Note that in our implementation, we invert the policy functions in the last stage, when characterizing the [steady-state equilibrium](#steady-state). \n",
    "\n",
    "Denoting by $g_a$ and $g_c$ the policy functions for asset and consumption choices respectively, these functions are defined on the Cartesian product of asset choices $\\times$ productivity levels: $[-\\underline{a},\\infty)\\times \\mathcal Y$. For any $(a,y)\\in[-\\underline{a},\\infty)\\times \\mathcal Y$, $g_a(a,y)$ is the beginning-of-period savings of an agent with productivity $y$ who ends the period with savings $a$, and $g_c(a,y)$ is the corresponding consumption.\n",
    "\n",
    "In practice, the asset choice grid is discretized and represented by the field `aGrid::Vector{T}` of the `economy::Economy` variable, of size `na`. The policy functions are matrices of size `na`$\\times$`ny` (where `ny` is the number of productivity levels in $\\mathcal Y$). Since policy functions are functions of end-of-period savings, `aGrid[ia]` represents the end-of-period savings of an agent endowed with beginning-of-period wealth `ga[ia,iy]` and productivity level `ys[iy]`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The EGM algorithm can be decomposed into the following steps:\n",
    "\n",
    "1. Take as given the post-tax gross interest rate $R$ and the post-tax wage $w$, and start from an initial guess for the asset policy function $g_a$ (seen as an $n_a\\times n_y$ matrix here). The policy function gives beginning-of-period savings as a function of end-of-period savings.\n",
    "\n",
    "2. Compute the [consumption policy function](#updating-gc) as a function of beginning-of-period savings (and not of end-of-period savings, as is otherwise the case in EGM). This is key for the next step. \n",
    "\n",
    "3. Update the [asset policy function using the Euler equation](#updating-ga). Thanks to the previous step, current-period consumption as a function of end-of-period savings follows directly from the Euler equation, without any root-finding. This is where EGM saves time compared to other methods. \n",
    "\n",
    "4. Stop if the update is small enough; otherwise, go back to step 2. Once stopped, [return the policy functions](#solve-EGM). \n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Consumption as a function of beginning-of-period savings   <a id=\"updating-gc\"></a>[<font size=1>(back to policy functions)</font>](#pol-fun)\n",
    "\n",
    "Computing consumption as a function of beginning-of-period savings is done using the individual budget constraint through the following steps:\n",
    "* We start from beginning-of-period asset holdings $a = a_{\\text{Grid},i_a}$ (given by the asset grid) and a productivity level $y=\\mathcal Y_{i_y}$. The indices $(i_a,i_y)$ refer to the *beginning of the period*. We are interested in the consumption associated with $i_a$ (savings $a$) and $i_y$ (productivity $y$), denoted by $c_{i_a,i_y}$.\n",
    "* Using linear interpolation, we compute the end-of-period savings $a^\\prime_{i_a,i_y}=lin(a,g_{a,i_y},a_\\text{Grid})$, where $lin$ is the interpolation function, $a$ is the beginning-of-period savings, $g_{a,i_y}$ is the vector of beginning-of-period savings in state $i_y$ (recall the EGM convention), and $a_\\text{Grid}$ is the corresponding vector of end-of-period savings. The linear interpolation function is implemented by `interpLinear` from the [`Utils`](./Utils.ipynb) notebook. \n",
    "* We then compute consumption directly from the budget constraint, $c = -a^\\prime + Ra + w(y\\,l_{\\text{supply},i_y})^{\\tau} + T$, where, as in the code, $\\tau$ is the parameter `τ` of `economy` and $T$ is the transfer `Tt`. Since labor supply does not depend on consumption (see below), no root-finding is required.\n",
    "\n",
    "In our GHH case, the labor supply function $l_{\\text{supply},i_y}$ is independent of the consumption choice, which simplifies this computation. Note that this function is implemented in `economy` and can be invoked as `l_supply(w,y)`. \n",
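    "\n",
    "As an illustration of the two steps above, here is a minimal, self-contained Julia sketch on a toy grid. It does not call the notebook's `interpLinear`: the helper `interp_lin` below is hypothetical and only assumed to mimic linear interpolation with linear extrapolation, and the borrowing limit is imposed by clamping $a^\\prime$ at `aGrid[1]` (an assumption about how the lower bound is treated). The scalar `inc` stands for post-tax labor income plus the transfer $T$.\n",
    "\n",
    "```julia\n",
    "# Hypothetical helper: interpolates ys (given on the increasing grid xs) at x,\n",
    "# extrapolating linearly outside [xs[1], xs[end]]; returns value and bracket index\n",
    "function interp_lin(x, xs, ys)\n",
    "    j = clamp(searchsortedlast(xs, x), 1, length(xs) - 1)\n",
    "    t = (x - xs[j]) / (xs[j+1] - xs[j])\n",
    "    return ys[j] + t * (ys[j+1] - ys[j]), j\n",
    "end\n",
    "\n",
    "# Consumption as a function of beginning-of-period wealth, for one productivity state.\n",
    "# ga[i] is the beginning-of-period wealth of an agent ending the period with aGrid[i].\n",
    "function consumption_from_egm(ga, aGrid, R, inc)\n",
    "    c = similar(aGrid)\n",
    "    for (ia, a) in enumerate(aGrid)\n",
    "        a′ = max(interp_lin(a, ga, aGrid)[1], aGrid[1])  # invert ga, impose a′ ≥ aGrid[1]\n",
    "        c[ia] = R * a + inc - a′                         # budget constraint\n",
    "    end\n",
    "    return c\n",
    "end\n",
    "\n",
    "aGrid = [0.0, 1.0, 2.0, 3.0]\n",
    "ga    = [-0.5, 0.5, 1.5, 2.5]   # toy EGM policy: a = ga(a′)\n",
    "c = consumption_from_egm(ga, aGrid, 1.02, 1.0)\n",
    "```\n",
    "\n",
    "With these toy numbers, an agent with beginning-of-period wealth $1$ lies halfway between `ga[2]` and `ga[3]`, so the agent saves $a^\\prime = 1.5$ and consumes $1.02 + 1 - 1.5 = 0.52$.\n",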
    "\n",
    "\n",
    "The update is invoked by: \n",
    "\n",
    "> `inverse_c_EGM!(solution, economy)`, \n",
    "\n",
    "where:\n",
    "* `solution::AiyagariSolution` is a mutable `struct` `AiyagariSolution` which is updated by the function (and hence contains the guess value mentioned above);\n",
    "* `economy::Economy` is an immutable `struct` `Economy` which contains economy parameters.\n",
    "\n",
    "The function returns `nothing` but updates `solution` in place.\n",
    "\n",
    "> ***Remark.*** The function inverts only the consumption policy function, which becomes a function of beginning-of-period savings. The labor and savings policy functions remain functions of end-of-period savings (as is standard in EGM). This state is only temporary and avoids defining an additional consumption policy function. \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "function inverse_c_EGM!(solution::AiyagariSolution,\n",
    "                      economy::Economy)::Nothing\n",
    "    \n",
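    "    @unpack Tt,τ,l_supply,na,a_min,aGrid,ny,ys = economy\n",
    "    @unpack ga,R,w = solution\n",
    "    cs = similar(ga)\n",
    "\n",
    "    for iy = 1:ny\n",
    "        for ia = 1:na\n",
    "            # end-of-period savings chosen with beginning-of-period wealth aGrid[ia],\n",
    "            # obtained by interpolating the EGM policy ga[:,iy] (x-grid) on aGrid (y-grid)\n",
    "            a′ = interpLinear(aGrid[ia], ga[:,iy], aGrid, na, a_min)[1]\n",
    "            # budget constraint: c = R a + w (y l)^τ + T - a′\n",
    "            cs[ia,iy] = R*aGrid[ia] + w*(ys[iy]*l_supply(w,ys[iy]))^τ - a′ + Tt\n",
    "        end\n",
    "    end\n",
    "    # gc is now c(a), a function of beginning-of-period savings\n",
    "    solution.gc = cs\n",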
    "    return nothing\n",
    "end;"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Updating the saving policy function using the Euler equation <a id=\"updating-ga\"></a>[<font size=1>(back to policy functions)</font>](#pol-fun)\n",
    "\n",
    "Updating the saving policy function works as follows. From the previous step (the function `inverse_c_EGM!`), we have next-period consumption, $(c^\\prime_{i_a,i_y})_{i_a=1,\\ldots,n_a,i_y=1,\\ldots,n_y}$, as a function of beginning-of-period savings. Using the Euler equation, present consumption is defined by:\n",
    "$$u^\\prime(c_{i_a,i_y}) = \\beta R \\sum_{j_y}\\Pi^y_{i_y,j_y}u^\\prime(c^\\prime_{i_a,j_y}),$$\n",
    "where $\\Pi^y$ is the productivity transition matrix and $c_{i_a,i_y}$ is a function of end-of-period savings. This computation is straightforward and only involves linear algebra and the inversion of $u^\\prime$. With GHH preferences, marginal utility also depends on labor supply, which is why the code calls `u′(c,l)` and `inv_u′(x,l)`.\n",
    "\n",
    "This is probably the *key* simplification offered by EGM: the index $i_a$ is the same on both sides of the equation. On the left-hand side, $i_a$ refers to end-of-period savings in the current period, while on the right-hand side, it refers to beginning-of-period savings in the next period. Hence, $c_{i_a,i_y}$ is current consumption as a function of end-of-period savings.\n",
    "\n",
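    "The rest of the computation is straightforward:\n",
    "* labor supply is obtained as $l_{\\text{supply}}(w,y)$, which in the GHH case depends neither on consumption nor on savings;\n",
    "* beginning-of-period savings are computed from the individual budget constraint as a function of end-of-period savings: $a_{i_a,i_y} = \\left(a_{\\text{Grid},i_a} + c_{i_a,i_y} - T - w(y_{i_y}l_{\\text{supply},i_y})^{\\tau}\\right)/R$.\n",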
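    "\n",
    "The Euler step boils down to one matrix product. The following minimal sketch (log utility is an assumption made for illustration; the notebook's GHH utility also depends on labor) checks that `u′(c)*Πy'` reproduces the sum over next-period productivity states weighted by `Πy`, then recovers beginning-of-period savings from the budget constraint. The names `euler_step` and `inc` are hypothetical; `inc[iy]` stands for post-tax labor income plus transfers.\n",
    "\n",
    "```julia\n",
    "# One EGM backward step with log utility: u′(c) = 1/c, so (u′)⁻¹(x) = 1/x.\n",
    "# c_next[ia,jy]: next-period consumption at beginning-of-period wealth aGrid[ia].\n",
    "function euler_step(c_next, aGrid, Πy, β, R, inc)\n",
    "    Eu = β * R * (1 ./ c_next) * Πy'     # Eu[ia,iy] = βR Σ_jy Πy[iy,jy] u′(c_next[ia,jy])\n",
    "    c  = 1 ./ Eu                          # current consumption, function of a′ = aGrid[ia]\n",
    "    a  = (aGrid .+ c .- inc') ./ R        # beginning-of-period savings, function of a′\n",
    "    return c, a\n",
    "end\n",
    "\n",
    "aGrid  = [0.0, 1.0, 2.0]\n",
    "Πy     = [0.9 0.1; 0.2 0.8]\n",
    "c_next = [0.8 1.2; 1.0 1.5; 1.3 1.9]\n",
    "c, a   = euler_step(c_next, aGrid, Πy, 0.96, 1.02, [0.5, 1.5])\n",
    "\n",
    "# the matrix product matches the explicit sum over jy\n",
    "Eu_loop = [0.96 * 1.02 * sum(Πy[iy, jy] / c_next[ia, jy] for jy in 1:2) for ia in 1:3, iy in 1:2]\n",
    "@assert 1 ./ c ≈ Eu_loop\n",
    "```\n",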
    "\n",
    "This is invoked via:\n",
    "> `euler_backward_EGM!(solution, economy)`, \n",
    "\n",
    "where (as before):\n",
    "* `solution::AiyagariSolution` is a mutable `struct` `AiyagariSolution` which is updated by the function;\n",
    "* `economy::Economy` is an immutable `struct` `Economy` which contains economy parameters.\n",
    "\n",
    "The function updates the policy functions for savings `ga`, consumption `gc`, and labor supply `gl` (all of them in `solution`) using the Euler equation for consumption and the (static) labor supply condition. *All* policy functions are then expressed as functions of end-of-period savings."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "function euler_backward_EGM!(solution::AiyagariSolution,\n",
    "                       economy::Economy)::Nothing\n",
    "    \n",
    "    @unpack β,Tt,u′,inv_u′,l_supply,na,aGrid,ny,ys,Πy,τ = economy\n",
    "    \n",
    "    # gc becomes c(a), a function of beginning-of-period savings\n",
    "    inverse_c_EGM!(solution, economy)\n",
    "    @unpack gc,R,w = solution\n",
    "    \n",
    "    cs = similar(gc) \n",
    "    ls = similar(gc)    \n",
    "    as = similar(gc)    \n",
    "    \n",
    "    # GHH: labor supply only depends on (w,y)\n",
    "    for ia = 1:na\n",
    "        for iy = 1:ny\n",
    "            ls[ia,iy] = l_supply(w,ys[iy]) \n",
    "        end\n",
    "    end\n",
    "\n",
    "    # Euler equation: Eu′s[ia,iy] = βR Σ_jy Πy[iy,jy] u′(gc[ia,jy], ls[ia,jy])\n",
    "    u′s_next = u′.(gc,ls)\n",
    "    Eu′s = β*R*u′s_next*Πy'\n",
    "    cs .= inv_u′.(Eu′s,ls) \n",
    "    \n",
    "    # cs is c(a′): next-period beginning-of-period = current end-of-period savings\n",
    "    for ia = 1:na\n",
    "        for iy = 1:ny\n",
    "            # budget constraint solved for beginning-of-period savings\n",
    "            as[ia,iy] = (aGrid[ia] + cs[ia,iy] - Tt - w*(ys[iy]*ls[ia,iy])^τ)/R\n",
    "        end\n",
    "    end\n",
    "    # as is a(a′), cs is c(a′), ls is l(a′)\n",
    "    \n",
    "    # updates policy function in solution\n",
    "    solution.gc .= cs\n",
    "    solution.ga .= as\n",
    "    solution.gl .= ls\n",
    "    return nothing\n",
    "end;"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Computing policy functions <a id=\"solve-EGM\"></a>[<font size=1>(back to policy functions)</font>](#pol-fun)\n",
    "\n",
    "We finally iterate on policy functions until the savings policy function no longer changes. Each iteration calls the function `euler_backward_EGM!`. \n",
    "\n",
    "The function is invoked by:\n",
    "> `policy_fun_EGM!(solution, economy; tol=1e-8, maxiter=10000, print_step=100)`, \n",
    "\n",
    "where:\n",
    "* `solution::AiyagariSolution` is a mutable `struct` `AiyagariSolution` where all policy functions are functions of end-of-period savings;\n",
    "* `economy::Economy` is an immutable `struct` `Economy` which contains economy parameters;\n",
    "* `tol` is a precision criterion to stop the convergence process;\n",
    "* `maxiter` is the maximum number of iterations (in case the policy function does not converge);\n",
    "* `print_step` sets how often (in iterations) the current criterion is printed. \n",
    "\n",
    "The function stops when:\n",
    "* either the relative difference between the policy function `ga` and its update is below the threshold `tol` (formally, if $g_a$ is the updated policy function and $\\tilde g_a$ the previous one, the criterion is $\\max_{i_a,i_y}|g_{a,i_a,i_y} - \\tilde g_{a,i_a,i_y}|/(|g_{a,i_a,i_y}| + |\\tilde g_{a,i_a,i_y}|) < \\varepsilon_{tol}$);\n",
    "* or the number of iterations exceeds `maxiter`.\n",
    "\n",
    "The printed message differs between the two cases."
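    ,"\n",
    "\n",
    "The stopping criterion can be written as a one-line function. The sketch below (hypothetical name `rel_gap`) mirrors the expression used in `policy_fun_EGM!`:\n",
    "\n",
    "```julia\n",
    "# max over grid points of |g - g̃| / (|g| + |g̃|)\n",
    "rel_gap(g, g_old) = maximum(abs.(g .- g_old) ./ (abs.(g) .+ abs.(g_old)))\n",
    "\n",
    "rel_gap([1.0, 2.0], [1.0, 2.2])   # 0.2/4.2 ≈ 0.0476\n",
    "```\n",
    "\n",
    "Note that if an entry is exactly zero in both the old and the new policy functions, the ratio is `0/0 = NaN`, `maximum` then returns `NaN`, and the test `test < tol` is never satisfied, so the loop runs until `maxiter`."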
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "function policy_fun_EGM!(solution::AiyagariSolution,\n",
    "                   economy::Economy;\n",
    "                   tol::T=1e-8, maxiter::I=10000,print_step::I=100)::Nothing where {T<:Real,I<:Integer}\n",
    "    # iterates on policy functions until convergence\n",
    "    \n",
    "    as  = similar(solution.ga)\n",
    "    as .= solution.ga\n",
    "    i = 0\n",
    "    while true \n",
    "        i += 1\n",
    "        euler_backward_EGM!(solution, economy) #updates policy functions once\n",
    "        test = maximum(abs.(solution.ga .- as) ./ (\n",
    "                abs.(solution.ga) .+ abs.(as))) #computation of the convergence criterion\n",
    "        if test < tol\n",
    "            # convergence is reached\n",
    "            println(\"Solved in \",i,\" \",\"iterations\")\n",
    "            break\n",
    "        end\n",
    "        if i > maxiter\n",
    "            # maximal nb of iterations is reached (but no convergence)\n",
    "            println(\"Convergence not reached after \",i,\" iterations. Criterion = \", test)\n",
    "            break\n",
    "        end\n",
    "        if (i%print_step == 0) \n",
    "            println(\"iteration: \", i , \" \", maximum(test))\n",
    "            flush(stdout)\n",
    "        end\n",
    "        as .= solution.ga\n",
    "    end \n",
    "    return nothing\n",
    "end;"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Computing stationary distribution <a id=\"stat-dist\"></a>[<font size=1>(back to menu)</font>](#Aiyagari)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The stationary distribution is computed as the normalized left eigenvector of the transition matrix associated with its largest eigenvalue (which equals 1 for a stochastic matrix). The computation thus involves two steps:\n",
    "1. the [transition matrix](#transition-mat),\n",
    "2. the [stationary distribution](#stat-dist-2) as an eigenvector."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## The transition matrix <a id=\"transition-mat\"></a>[<font size=1>(back to stationary distribution menu)</font>](#stat-dist)\n",
    "\n",
    "We compute the transition matrix $T$ over the product grid of savings and productivity. The transition matrix is thus of size $n_an_y\\times n_an_y$. Each element of the matrix is the probability of moving from one node of the product grid (characterized by a position on the asset grid and a position on the productivity grid) to another node of the product grid. \n",
    "\n",
    "### The non-EGM case\n",
    "\n",
    "The standard (non-EGM) way to compute the transition matrix can be summarized as follows. We take as given a starting position on the product grid, characterized by indices $i_a$ and $i_y$ (for savings and productivity, respectively), as well as an arrival index $j_y$ on the productivity grid. The arrival index $j_a$ on the asset grid is endogenous and computed with the savings policy function, using the current productivity level $\\mathcal Y_{i_y}$ and beginning-of-period savings $a_{\\text{Grid},i_a}$ (details below). Then, up to the interpolation weights discussed below, the transition probability between $(i_a,i_y)$ and $(j_a,j_y)$ is given by the transition probability between productivity indices $i_y$ and $j_y$, $\\Pi^y_{i_y,j_y}$, from the productivity transition matrix $\\Pi^y$. \n",
    "\n",
    "#### How does one *exactly* compute the endogenous index $j_a$ for the asset grid?\n",
    "\n",
    "The savings policy function gives the end-of-period savings $a^\\prime = g_{a,i_a,i_y}$ as a function of the current productivity level $\\mathcal Y_{i_y}$ and beginning-of-period savings $a_{\\text{Grid},i_a}$. However, there is no reason for $a^\\prime$ to coincide with one of the points of the asset grid: $a^\\prime$ generally lies between two grid points, i.e., there is $j_a$ such that $a_{\\text{Grid},j_a}\\le a^\\prime< a_{\\text{Grid},j_a+1}$. We have to handle corner cases, since such a $j_a$ does not exist if $a^\\prime \\ge a_{\\text{Grid},n_a}$ or $a^\\prime < a_{\\text{Grid},1}$. Hence, we define $j_a = 1$ if $a^\\prime < a_{\\text{Grid},1}$, $j_a = n_a-1$ if $a^\\prime \\ge a_{\\text{Grid},n_a}$, and $j_a=\\max \\{j : a_{\\text{Grid},j} \\le a^\\prime\\}$ otherwise.\n",
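    "\n",
    "In Julia, this index is what `searchsortedlast` returns, up to the two corner cases, which a `clamp` handles. A minimal sketch (the function name `bracket_index` is hypothetical; `transitionMat` below does the same with an explicit `if`):\n",
    "\n",
    "```julia\n",
    "# j_a = max{j : aGrid[j] ≤ a′}, forced into 1:na-1 so that j_a+1 is a valid index\n",
    "bracket_index(a′, aGrid) = clamp(searchsortedlast(aGrid, a′), 1, length(aGrid) - 1)\n",
    "\n",
    "aGrid = [0.0, 1.0, 2.0, 4.0]\n",
    "[bracket_index(x, aGrid) for x in (-0.5, 0.0, 1.5, 4.0, 5.0)]  # [1, 1, 2, 3, 3]\n",
    "```\n",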
    "\n",
    "This means that the transition is actually from $(i_a,i_y)$ to $(j_a,j_y)$ or $(j_a+1,j_y)$. We assume that the probability mass is split between $j_a$ and $j_a+1$ by linear interpolation. Observe that:\n",
    "$$a^\\prime = \\underbrace{\\frac{a_{\\text{Grid},j_a+1}-a^\\prime}{a_{\\text{Grid},j_a+1}-a_{\\text{Grid},j_a}}}_{=p_a}a_{\\text{Grid},j_a} + \\underbrace{\\frac{a^\\prime-a_{\\text{Grid},j_a}}{a_{\\text{Grid},j_a+1}-a_{\\text{Grid},j_a}}}_{=1-p_{a}}a_{\\text{Grid},j_a+1},$$\n",
    "where $p_{a}$ and $1-p_{a}$ can be interpreted as probabilities since $0\\le p_{a}\\le1$. Put differently, we assume that $a^\\prime$ is the barycenter of $a_{\\text{Grid},j_a}$ and $a_{\\text{Grid},j_a+1}$ with respective weights proportional to $a_{\\text{Grid},j_a+1} -a^\\prime$ and $a^\\prime - a_{\\text{Grid},j_a}$.\n",
    "\n",
    "The transition probability from $(i_a,i_y)$ to $(j_a,j_y)$ is thus $\\Pi^y_{i_yj_y}$, the probability of switching from $i_y$ to $j_y$, times $p_a=\\frac{a_{\\text{Grid},j_a+1}-a^\\prime}{a_{\\text{Grid},j_a+1}-a_{\\text{Grid},j_a}}$, the probability of moving from $a^\\prime$ to $a_{\\text{Grid},j_a}$. Similarly, the transition probability from $(i_a,i_y)$ to $(j_a+1,j_y)$ is $(1-p_a)\\Pi^y_{i_yj_y}$. \n",
    "\n",
    "\n",
    "#### Summary\n",
    "\n",
    "Take as given indices $(i_a,i_y)$ and $j_y$ and define $a^\\prime = g_{a,i_a,i_y}$. If it exists, let $j_a$ be such that $a_{\\text{Grid},j_a}\\le a^\\prime< a_{\\text{Grid},j_a+1}$. Otherwise, let $j_a=1$ if $a^\\prime< a_{\\text{Grid},1}$, or $j_a=n_a-1$ if $a^\\prime \\ge a_{\\text{Grid},n_a}$. Defining\n",
    "$$p_a = \\max(\\min(\\frac{a_{\\text{Grid},j_a+1}-a^\\prime}{a_{\\text{Grid},j_a+1}-a_{\\text{Grid},j_a}},1),0),$$\n",
    "we have:\n",
    "\\begin{align}\n",
    "\\Pi_{(i_a,i_y),(j_a,j_y)} &= p_a\\Pi^y_{i_yj_y}, \\\\\n",
    "\\Pi_{(i_a,i_y),(j_a+1,j_y)} &= (1-p_a)\\Pi^y_{i_yj_y}, \\\\\n",
    "\\forall j \\notin\\{j_a,j_a+1\\}, \\Pi_{(i_a,i_y),(j,j_y)}&= 0. \\\\\n",
    "\\end{align}\n",
    "\n",
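    "The bounds on $p_a$ handle the corner cases for $j_a$: if $a^\\prime < a_{\\text{Grid},1}$, the unclamped ratio exceeds 1 and all the mass goes to $a_{\\text{Grid},1}$; if $a^\\prime \\ge a_{\\text{Grid},n_a}$, the ratio is nonpositive and all the mass goes to $a_{\\text{Grid},n_a}$. \n",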
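    "\n",
    "The summary can be turned into a small dense transition matrix. The following sketch (hypothetical name `toy_transition`, dense storage instead of `SparseMatrixCSC`) uses the same ordering of states as the notebook, namely index `(iy-1)*na + ia`, and checks that each row sums to one:\n",
    "\n",
    "```julia\n",
    "# ga[ia,iy] = a′ (non-EGM convention); returns T[(iy-1)*na+ia, (jy-1)*na+ja]\n",
    "function toy_transition(ga, aGrid, Πy)\n",
    "    na, ny = size(ga)\n",
    "    T = zeros(na*ny, na*ny)\n",
    "    for iy in 1:ny, ia in 1:na\n",
    "        a′ = ga[ia, iy]\n",
    "        ja = clamp(searchsortedlast(aGrid, a′), 1, na - 1)\n",
    "        pa = clamp((aGrid[ja+1] - a′) / (aGrid[ja+1] - aGrid[ja]), 0.0, 1.0)\n",
    "        for jy in 1:ny\n",
    "            T[(iy-1)*na+ia, (jy-1)*na+ja]   += pa * Πy[iy, jy]\n",
    "            T[(iy-1)*na+ia, (jy-1)*na+ja+1] += (1 - pa) * Πy[iy, jy]\n",
    "        end\n",
    "    end\n",
    "    return T\n",
    "end\n",
    "\n",
    "aGrid = [0.0, 1.0, 2.0]\n",
    "Πy    = [0.9 0.1; 0.2 0.8]\n",
    "ga    = [0.25 0.5; 1.0 1.5; 2.5 2.0]\n",
    "T = toy_transition(ga, aGrid, Πy)\n",
    "@assert all(sum(T, dims=2) .≈ 1)\n",
    "```\n",
    "\n",
    "For instance, $a^\\prime=0.25$ gives $j_a=1$ and $p_a=0.75$, so `T[1,1]` $=0.75\\times0.9=0.675$; and $a^\\prime=2.5$ lies above the grid, so $p_a=0$ and `T[3,3]` $=0.9$.\n",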
    "\n",
    "### The EGM case\n",
    "\n",
    "This needs to be slightly adapted for EGM, since the savings policy function characterizes beginning-of-period savings as a function of end-of-period savings. The idea is to perform the inversion of the policy function inside the function itself. Again, we start from indices $(i_a,i_y)$ and $j_y$. Because of EGM, we use linear interpolation to obtain end-of-period savings, $a^\\prime=lin(a_{\\text{Grid},i_a},g_{a,i_y},a_\\text{Grid})$, together with the index $j_a$ such that the interpolation is performed between indices $j_a$ and $j_a+1$. The implementation is then exactly as in the non-EGM case, except that it uses the following identity:\n",
    "$$\\frac{a_{\\text{Grid},j_a+1}-a^\\prime}{a_{\\text{Grid},j_a+1}-a_{\\text{Grid},j_a}}=\\frac{g_{a,j_a+1,i_y}-a_{\\text{Grid},i_a}}{g_{a,j_a+1,i_y}-g_{a,j_a,i_y}},$$\n",
    "which directly comes from the linear interpolation (the function that maps $[g_{a,j_a,i_y},g_{a,j_a+1,i_y}]$ onto $[a_{\\text{Grid},j_a},a_{\\text{Grid},j_a+1}]$ is affine) and from the fact that the interpolation index $j_a$ is the same for the x-grid (here, $g_{a,i_y}$) and the y-grid (here, $a_\\text{Grid}$). Observe that with this identity, the linear interpolation is only needed to compute the index $j_a$; the value $a^\\prime$ itself is not needed. \n",
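    "\n",
    "A quick numerical check of this identity on toy numbers (the grid, the policy `g`, and the wealth level `a` are arbitrary illustrative values):\n",
    "\n",
    "```julia\n",
    "aGrid = [0.0, 1.0, 2.0, 3.0]      # end-of-period savings (y-grid)\n",
    "g     = [-0.5, 0.4, 1.6, 2.5]     # EGM policy: beginning-of-period savings (x-grid)\n",
    "a     = 1.0                        # beginning-of-period savings aGrid[i_a]\n",
    "\n",
    "ja = searchsortedlast(g, a)        # same bracket index on both grids: 2 here\n",
    "a′ = aGrid[ja] + (a - g[ja]) / (g[ja+1] - g[ja]) * (aGrid[ja+1] - aGrid[ja])\n",
    "\n",
    "lhs = (aGrid[ja+1] - a′) / (aGrid[ja+1] - aGrid[ja])   # weight computed with a′\n",
    "rhs = (g[ja+1] - a) / (g[ja+1] - g[ja])                # weight computed without a′\n",
    "@assert lhs ≈ rhs                                      # both ≈ 0.5\n",
    "```\n",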
    "\n",
    "#### The implementation\n",
    "\n",
    "\n",
    "There are two implementations, corresponding to the no-EGM and EGM cases. \n",
    "\n",
    "The no-EGM function is invoked by \n",
    "\n",
    "> `transitionMat(ga::Matrix{T},economy::Economy)::SparseMatrixCSC{T,I}`, \n",
    "\n",
    "where:\n",
    "* `ga::Matrix{T}` is the saving policy function (a `na`$\\times$ `ny` matrix defined on the product grid of assets $\\times$ productivity);\n",
    "\n",
    "* `economy::Economy` is the economy.\n",
    "\n",
    "The function returns a sparse matrix (of type `SparseMatrixCSC{T,I}` from the package `SparseArrays`) such that:\n",
    "* the sparse matrix is of size `na⋅ny`$\\times$`na⋅ny`;\n",
    "* a matrix element corresponds to the probability to switch from a pair of (savings, productivity level) to another pair of (savings, productivity level).\n",
    "\n",
    "The EGM function is invoked by \n",
    "\n",
    "> `transitionMat_EGM(ga::Matrix{T},economy::Economy)::SparseMatrixCSC{T,I}`, \n",
    "\n",
    "with the same specification as before.\n",
    "\n",
    "***Remark.*** In the non-EGM case, the policy function `ga` should express end-of-period savings as a function of beginning-of-period savings. In the EGM case, the policy function `ga` should express beginning-of-period savings as a function of end-of-period savings.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# no-EGM case: a′ = ga(a)\n",
    "function transitionMat(ga::Matrix{T}, economy::Economy{T,T1,T2,T3,T4,T5,I})::SparseMatrixCSC{T,I} where {\n",
    "                            T<:Real,T1<:Function,T2<:Function,T3<:Function,\n",
    "                            T4<:Function,T5<:Function,I<:Int64}\n",
    "    @unpack na,a_min,aGrid,ny,Πy = economy\n",
    "    \n",
    "    trans  = spzeros(T, I, na*ny, na*ny)\n",
    "    p      = zero(T)\n",
    "    i_mat  = zero(I)\n",
    "    j_mat  = zero(I)\n",
    "    for ia = 1:na\n",
    "        for iy = 1:ny\n",
    "            ja = searchsortedlast(aGrid, ga[ia,iy]) \n",
    "            # ja is the largest index such that aGrid[ja] ≤ ga[ia,iy] (hence aGrid[ja+1] > ga[ia,iy]);\n",
    "            # it is 0 if ga[ia,iy] < aGrid[1]\n",
    "\n",
    "            # Corner cases: keep ja in 1:na-1 so that both ja and ja+1 are valid indices\n",
    "            if ja == 0\n",
    "                ja += 1\n",
    "            elseif (ja==na)\n",
    "                ja -= 1\n",
    "            end  \n",
    "            p = (ga[ia,iy]-aGrid[ja])/(aGrid[ja+1] - aGrid[ja])\n",
    "            p = min(max(p,zero(T)),one(T))\n",
    "            i_mat = (iy-1)*na\n",
    "            for jy = 1:ny\n",
    "                j_mat = (jy-1)*na\n",
    "                trans[i_mat+ia,j_mat+ja+1] += p * Πy[iy,jy]\n",
    "                trans[i_mat+ia,j_mat+ja]   += (one(T)-p) * Πy[iy,jy]\n",
    "            end\n",
    "        end\n",
    "    end   \n",
    "    return trans\n",
    "end;"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# EGM case: a = ga(a′)\n",
    "function transitionMat_EGM(ga::Matrix{T}, economy::Economy{T,T1,T2,T3,T4,T5,I})::SparseMatrixCSC{T,I} where {\n",
    "                            T<:Real,T1<:Function,T2<:Function,T3<:Function,T4<:Function,T5<:Function,I<:Int64}\n",
    "    @unpack na,a_min,aGrid,ny,Πy = economy\n",
    "    \n",
    "    trans  = spzeros(T, I, na*ny, na*ny)\n",
    "    p      = zero(T)\n",
    "    a′     = zero(T)\n",
    "    ja     = zero(I)\n",
    "    i_mat  = zero(I)\n",
    "    j_mat  = zero(I)\n",
    "    for ia = 1:na\n",
    "        for iy = 1:ny\n",
    "            a′, ja = interpLinear(aGrid[ia], ga[:,iy], aGrid, na, a_min) \n",
    "            p = (aGrid[ia] - ga[ja,iy])/(\n",
    "                    ga[ja+1,iy] - ga[ja,iy])\n",
    "            p = min(max(p,zero(T)),one(T))\n",
    "            i_mat = (iy-1)*na\n",
    "            for jy = 1:ny\n",
    "                j_mat = (jy-1)*na\n",
    "                trans[i_mat+ia,j_mat+ja+1] = p * Πy[iy,jy]\n",
    "                trans[i_mat+ia,j_mat+ja] = (one(T)-p) * Πy[iy,jy]\n",
    "            end\n",
    "        end\n",
    "    end   \n",
    "    return trans\n",
    "end;"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## The stationary distribution <a id=\"stat-dist-2\"></a>[<font size=1>(back to stationary distribution menu)</font>](#stat-dist)\n",
    "\n",
    "We compute the stationary distribution using the [transition matrix](#transition-mat) discussed above. The stationary distribution $\\pi_T$ of the transition matrix $T$ is the (row) vector of size $1\\times n_a n_y$ that satisfies $\\pi_{T,i} \\ge 0$ for all $i$, $\\sum_i \\pi_{T,i} = 1$, and $\\pi_T T = \\pi_T$. In other words, $\\pi_T^\\top$ is the normalized eigenvector of $T^\\top$ associated with the eigenvalue 1. If $T$ is irreducible, $\\pi_T$ exists and is unique. If $T$ is moreover aperiodic, $T^k$ converges and $\\lim_{k\\rightarrow \\infty} T^k = \\mathbf{1}\\pi_T$.\n",
    "\n",
    "In practice, we rely on the function `powm!` of the package `IterativeSolvers`, applied to the transpose of the transition matrix. This function computes the largest eigenvalue (in absolute value) of a matrix and the associated eigenvector.\n",
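    "\n",
    "For intuition, the following sketch runs a plain power iteration on $T^\\top$ (it does not use `IterativeSolvers`; the function name `toy_stationary` is hypothetical) and normalizes the iterate so that it sums to one at each step. On a two-state chain, the stationary distribution can be checked by hand: with $T=\\begin{pmatrix}0.9&0.1\\\\0.2&0.8\\end{pmatrix}$, $\\pi_T=(2/3,1/3)$.\n",
    "\n",
    "```julia\n",
    "# Power iteration: x ← Tᵀx, renormalized to sum to one, until the update is tiny\n",
    "function toy_stationary(T; tol=1e-12, maxiter=100_000)\n",
    "    n = size(T, 1)\n",
    "    x = fill(1/n, n)                 # start from the uniform distribution\n",
    "    for _ in 1:maxiter\n",
    "        xnew = T' * x\n",
    "        xnew ./= sum(xnew)\n",
    "        maximum(abs.(xnew .- x)) < tol && return xnew\n",
    "        x = xnew\n",
    "    end\n",
    "    return x\n",
    "end\n",
    "\n",
    "π_T = toy_stationary([0.9 0.1; 0.2 0.8])   # ≈ [2/3, 1/3]\n",
    "```\n",
    "\n",
    "The convergence speed is governed by the modulus of the second-largest eigenvalue (here $0.7$): the closer it is to one, the more iterations are needed.\n",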
    "\n",
    "The function actually computing the stationary distribution is invoked with: \n",
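    "> `stationaryDist(M; ωys, nys, tol=1e-16, maxiter=100000)`, \n",
    "\n",
    "where:\n",
    "* `M::SparseMatrixCSC{T,I}` is a (sparse) transition matrix that results from the function `transitionMat` (a method for dense matrices also exists);\n",
    "* `ωys::Vector{T}` contains the weights of the productivity blocks and must sum to 1;\n",
    "* `nys::Vector{I}` contains the number of productivity states in each block;\n",
    "* `tol` is a precision criterion to stop the convergence process (default `1e-16` for the sparse method and `1e-12` for the dense one);\n",
    "* `maxiter` is the maximum number of iterations (default `100000`), in case the computation does not converge.\n",
    "\n",
    "The function treats the transition matrix as block-diagonal across productivity blocks: it computes the stationary distribution of each diagonal block separately, as the normalized eigenvector of the transposed block associated with its largest eigenvalue (which is $1$), and scales it by the corresponding weight in `ωys`. When `M` is indeed block-diagonal, the returned vector $\\Pi$ of size `na⋅ny` satisfies $\\Pi M=\\Pi$ and is the stationary distribution."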
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "function stationaryDist(tP::AbstractMatrix{T}; ωys::Vector{T}, nys::Vector{I}, \n",
    "        tol::T=1e-12, maxiter::I=100000)::Vector{T} where {T<:Real,I<:Integer}\n",
    "    \n",
    "    @assert norm(sum(ωys)-one(T))<tol\n",
    "    \n",
    "    ntP = size(tP)[2]\n",
    "    na  = trunc(Int,ntP/sum(nys))\n",
    "    tPs = Array{Matrix{T},1}(undef, length(ωys))\n",
    "    sD  = Array{T,1}(undef, ntP)\n",
    "    \n",
    "    for k in eachindex(ωys)\n",
    "        # indices of block k: states are ordered by productivity, with na asset\n",
    "        # points per productivity state, so the offset cumulates previous blocks\n",
    "        offset = na*sum(nys[1:k-1])\n",
    "        idx    = (offset+1):(offset+na*nys[k])\n",
    "        tPs[k] = Matrix(tP[idx,idx])\n",
    "        x      = ones(T,na*nys[k])\n",
    "        # power iteration on the transpose: x becomes the leading eigenvector\n",
    "        powm!(tPs[k]', x, maxiter = maxiter,tol = tol)\n",
    "        x ./= sum(x) \n",
    "        \n",
    "        @assert dimStoch(tPs[k]) == 2 # the matrix is row-stochastic\n",
    "        @assert norm(x'*tPs[k] - x') < ntP*tol\n",
    "        sD[idx] = ωys[k]*copy(x)\n",
    "    end\n",
    "    return sD    \n",
    "end\n",
    "\n",
    "function stationaryDist(M::SparseMatrixCSC{T,I}; ωys::Vector{T}, nys::Vector{I}, \n",
    "        tol::T=1e-16, maxiter::I=100000)::Vector{T} where {T<:Real,I<:Integer}\n",
    "    return stationaryDist(Matrix(M), ωys=ωys, nys=nys, tol=tol, maxiter=maxiter)\n",
    "end;"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# The steady-state equilibrium  <a id=\"steady-state\"></a>[<font size=1>(back to menu)</font>](#Aiyagari)\n",
    "\n",
    "There are three functions:\n",
    "* the [computation](#computation-steady-state) of the steady-state equilibrium;\n",
    "* the [verification](#consistency-steady-state) that the previous solution is consistent;\n",
    "* the [description](#description-steady-state) of the model solution."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Computation of the steady-state equilibrium <a id=\"computation-steady-state\"></a>[<font size=1>(back to menu)</font>](#steady-state)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The function: \n",
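    "> `steady(economy; print_step=50, tolEGM=1e-8, maxiterEGM=100000, tolSD=1e-16, maxiterSD=500000)` \n",
    "\n",
    "computes the steady-state solution of the Aiyagari model, where (as before):\n",
    "* `economy::Economy` is an immutable `struct` `Economy` which contains economy parameters;\n",
    "* `print_step` sets how often the EGM convergence criterion is printed;\n",
    "* `tolEGM` and `maxiterEGM` are the precision criterion and the maximum number of iterations for the computation of policy functions (passed to `policy_fun_EGM!`);\n",
    "* `tolSD` and `maxiterSD` are the precision criterion and the maximum number of iterations for the computation of the stationary distribution (passed to `stationaryDist`). \n",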
    "\n",
    "The function returns the steady-state allocation under the form of a mutable `struct` of type `AiyagariSolution{T,I}`. \n",
    "\n",
    "The function `steady` relies on previous functions, in particular `policy_fun_EGM!` for computing steady-state policy functions and `stationaryDist` for computing the stationary distribution."
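    ,"\n",
    "\n",
    "For reference, the aggregate capital stock set in `steady` can be read as the inversion of the firm's first-order condition. Assuming, as the code suggests, a Cobb-Douglas technology $Y=K^\\alpha L^{1-\\alpha}$ and a post-tax gross return $R = 1+(1-\\tau_k)\\left(\\alpha (K/L)^{\\alpha-1}-\\delta\\right)$ with capital tax $\\tau_k$ (`τk`), we get:\n",
    "$$\\frac{R-1}{1-\\tau_k}+\\delta = \\alpha\\left(\\frac{K}{L}\\right)^{\\alpha-1} \\iff K = L\\left(\\frac{\\frac{R-1}{1-\\tau_k}+\\delta}{\\alpha}\\right)^{\\frac{1}{\\alpha-1}},$$\n",
    "which is the expression used for `solution.K`; output is then $Y=K^\\alpha L^{1-\\alpha}$, public debt is $B=A-K$, and public spending follows from the resource constraint, $G = Y-\\delta K-C$."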
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "function steady(economy::Economy;print_step::I=50,\n",
    "                tolEGM::T=1e-8, maxiterEGM::I=100000, \n",
    "                tolSD::T=1e-16, maxiterSD::I=500000)::AiyagariSolution{T,I} where {T<:Real,I<:Integer}\n",
    "\n",
    "    @unpack β,α,δ,τ,Tt,τk,u′,l_supply,na,a_min,a_max,aGrid,ny,ys,ωys,nys = economy\n",
    "    solution = AiyagariSolution(economy)\n",
    "    \n",
    "    # computing steady-state policy function\n",
    "    policy_fun_EGM!(solution, economy, tol=tolEGM, maxiter=maxiterEGM, print_step=print_step)\n",
    "    \n",
    "    @unpack ga,R,w = solution\n",
    "    \n",
    "    resE = similar(ga)\n",
    "    as = similar(ga)  #policy rules as a function of beginning of period savings\n",
    "    ls = similar(ga)  #policy rules as a function of beginning of period savings\n",
    "    cs = similar(ga)  #policy rules as a function of beginning of period savings\n",
    "\n",
    "    # we 'invert' policy functions (to obtain policy rules as  a function of beginning of period savings)\n",
    "    err = zero(T)\n",
    "    for ia = 1:na\n",
    "        for iy = 1:ny\n",
    "            as[ia,iy] = interpLinear(aGrid[ia], ga[:,iy], aGrid, na, a_min)[1]\n",
    "            ls[ia,iy] = l_supply(w,ys[iy])\n",
    "            cs[ia,iy] = (R*aGrid[ia] - as[ia,iy] +  w*(ys[iy]*ls[ia,iy])^τ + Tt)\n",
    "        end\n",
    "    end\n",
    "    solution.ga = as\n",
    "    solution.gc = cs\n",
    "    solution.gl = ls\n",
    "    \n",
    "    # computing stationary distribution\n",
    "    @assert norm(transitionMat_EGM(ga,economy) .- transitionMat(as,economy))<1e-12\n",
    "    solution.transitMat = transitionMat(as,economy)\n",
    "    solution.stationaryDist = reshape(stationaryDist(solution.transitMat,\n",
    "            ωys=ωys,nys=nys,tol=tolSD,maxiter=maxiterSD), na, ny)\n",
    " \n",
    "    # We compute aggregate quantities\n",
    "    solution.A = sum(solution.stationaryDist.*as) #aggregate savings\n",
    "    solution.C = sum(solution.stationaryDist.*cs) #aggregate consumption\n",
    "    solution.L = sum(solution.stationaryDist.*(repeat(ys,1,na)'.*ls)) #aggregate labor supply\n",
    "    solution.K = solution.L*(((R-1)/(1-τk)+δ)/α)^(one(T)/(α-one(T))) #aggregate capital\n",
    "    solution.Y = (solution.K)^α*(solution.L)^(1-α)\n",
    "\n",
    "    solution.B = (solution.A)-(solution.K)\n",
    "    solution.G = solution.Y- δ*solution.K - solution.C\n",
    "    # We check Euler equations by computing their residuals\n",
    "    solution.residEuler = u′.(cs,ls) - β*R*reshape(\n",
    "        solution.transitMat*reshape(u′.(cs,ls),na*ny,1),na,ny)\n",
    "\n",
    "    return solution\n",
    "end;"
   ]
  },
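  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The inversion step above relies on `interpLinear`, which is not reproduced in this section. For illustration only, the following is a minimal sketch of what inverting an EGM policy involves for a single productivity level. It assumes linear interpolation, linear extrapolation above the last endogenous grid point, and a borrowing constraint that binds below the first one. The name `invert_policy` is hypothetical, and `interpLinear` may handle the boundaries differently.\n",
    "\n",
    "```julia\n",
    "# Hypothetical helper (not part of the project): inverts an EGM policy for one productivity level.\n",
    "# a_end[j] : end-of-period savings grid (the role played by `aGrid`)\n",
    "# a_beg[j] : beginning-of-period wealth for which a_end[j] is optimal (the role of `ga[:,iy]`)\n",
    "# Returns a′(a) evaluated at the points `a_eval`.\n",
    "function invert_policy(a_end::Vector{Float64}, a_beg::Vector{Float64}, a_eval::Vector{Float64})\n",
    "    out = similar(a_eval)\n",
    "    n = length(a_beg)\n",
    "    for (i, a) in enumerate(a_eval)\n",
    "        if a <= a_beg[1]\n",
    "            out[i] = a_end[1]    # borrowing constraint binds: save the lowest grid value\n",
    "        else\n",
    "            # bracket a_beg[j] <= a; capping j at n-1 extrapolates linearly above the grid\n",
    "            j = min(searchsortedlast(a_beg, a), n - 1)\n",
    "            t = (a - a_beg[j]) / (a_beg[j+1] - a_beg[j])\n",
    "            out[i] = a_end[j] + t * (a_end[j+1] - a_end[j])\n",
    "        end\n",
    "    end\n",
    "    return out\n",
    "end\n",
    "\n",
    "# Toy example: the endogenous grid is a = a′ + 0.5, so above the constraint a′ = a - 0.5\n",
    "a_end = [0.0, 1.0, 2.0, 3.0]\n",
    "a_beg = [0.5, 1.5, 2.5, 3.5]\n",
    "invert_policy(a_end, a_beg, [0.0, 1.0, 2.0, 4.0])  # → [0.0, 0.5, 1.5, 3.5]\n",
    "```\n",
    "\n",
    "Wealth levels below the first endogenous grid point `a_beg[1]` correspond to borrowing-constrained agents, who would like to save less than the lowest grid value. This is why the sketch assigns them `a_end[1]`."
   ]
  },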
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Checking the consistency of the solution  <a id=\"consistency-steady-state\"></a>[<font size=1>(back to menu)</font>](#steady-state)\n",
    "\n",
    "\n",
    "The function: \n",
    "> `check_solution(solution::AiyagariSolution, economy::Economy; noprint::Bool=false)::Nothing`\n",
    "\n",
    "verifies that the solution of the Aiyagari model is internally consistent. The function returns `nothing` but raises a warning (`@warn`) whenever an inconsistency is found. The function's arguments are:\n",
    "* the Aiyagari solution `solution` of type `AiyagariSolution`,\n",
    "* the economy parameters `economy` of type `Economy`,\n",
    "* the optional Boolean `noprint`, which suppresses the messages of passed checks when set to `true` (warnings are always displayed).\n",
    "\n",
    "The following elements are tested (in this order):\n",
    "* the stationary distribution sums to 1;\n",
    "* the marginal of the stationary distribution over productivity levels matches `Sy`;\n",
    "* the aggregate labor supply is consistent with labor policy functions;\n",
    "* capital is consistent with the modified golden rule and the aggregate labor supply;\n",
    "* aggregate savings are consistent with end-of-period savings (saving policy function);\n",
    "* aggregate savings are consistent with beginning-of-period savings (asset grid);\n",
    "* aggregate consumption is consistent with the aggregated individual budget constraints;\n",
    "* the government budget constraint holds;\n",
    "* the economy's resource constraint holds."
   ]
  },
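  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The last two checks are closely related. As a short derivation, assume the distribution has converged, so that aggregate beginning-of-period wealth equals $A$ (which the convergence-for-savings check tests). Also use the individual budget constraint $c = Ra - a' + w\\,(y\\,l)^\\tau + T$ from the steady-state computation, with $T$ standing for `Tt`. Aggregate consumption is then\n",
    "\n",
    "$$C = (R-1)A + w\\,L_\\tau + T,\\qquad L_\\tau = \\int (y\\,l)^\\tau\\,\\Lambda(da,dy).$$\n",
    "\n",
    "Public spending is computed from the resource constraint, $G = K^\\alpha L^{1-\\alpha} - \\delta K - C$. Substituting $C$ and writing $B = A - K$ gives the government budget constraint\n",
    "\n",
    "$$G + T + (R-1)B = K^\\alpha L^{1-\\alpha} - \\delta K - (R-1)K - w\\,L_\\tau.$$\n",
    "\n",
    "In this setting the resource-constraint check therefore holds by construction. The government-budget check has to count the transfers $T$ among public outlays."
   ]
  },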
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "function check_solution(solution::AiyagariSolution, economy::Economy; noprint::Bool=false)::Nothing\n",
    "    @unpack β,α,κ,δ,τ,τk,Tt,na,a_min,aGrid,ny,ys,Sy = economy\n",
    "    @unpack ga,gc,gl,R,w,A,K,C,L,G,stationaryDist = solution\n",
    "    if !(sum(stationaryDist) ≈ 1.0)\n",
    "        @warn(\"error in stationary distribution. Sum: \", round(sum(stationaryDist),digits=4))\n",
    "    else\n",
    "        (!noprint)&&println(\"Passed: stationary distribution. Sum: \", round(sum(stationaryDist),digits=4))\n",
    "    end\n",
    "    if !(all(sum(stationaryDist,dims=1)[1,:] .≈ Sy))\n",
    "        @warn(\"error in stationary distribution vs Sy. ||∫ₐΛ(da,y) - Sy||∞: \", \n",
    "            round(norm(sum(stationaryDist,dims=1)[1,:] - Sy, Inf),digits=4))\n",
    "    else\n",
    "        (!noprint)&&println(\"Passed: stationary distribution vs Sy. ||∫ₐΛ(da,y) - Sy||∞: \", \n",
    "            round(norm(sum(stationaryDist,dims=1)[1,:] - Sy, Inf),digits=4))\n",
    "    end\n",
    "    if !(sum(stationaryDist .* ((repeat(ys,1,na)' .* gl))) ≈ L)\n",
    "        @warn(\"error in aggregate labor supply. L=\", round(L,digits=4), \n",
    "            \"; ∫lᵢℓ(di)=\", round(sum(stationaryDist .* ((repeat(ys,1,na)' .* gl))),digits=4))\n",
    "    else\n",
    "        (!noprint)&&println(\"Passed: aggregate labor supply. L=\", round(L,digits=4), \n",
    "            \"; ∫lᵢℓ(di)=\", round(sum(stationaryDist .* ((repeat(ys,1,na)' .* gl))),digits=4))\n",
    "    end\n",
    "    if !(L*((1/β - (1-δ))/α)^(1/(α-1)) ≈ K)\n",
    "        @warn(\"error in capital. K=\", round(K,digits=4), \"; L*(K_FB/L_FB)=\", \n",
    "            round(L*((1/β - (1-δ))/α)^(1/(α-1)),digits=4))\n",
    "    else\n",
    "        (!noprint)&&println(\"Passed: capital. K=\", round(K,digits=4), \"; L*(K_FB/L_FB)=\", \n",
    "            round(L*((1/β - (1-δ))/α)^(1/(α-1)),digits=4))\n",
    "    end\n",
    "    if !(A ≈ sum(stationaryDist.*ga))\n",
    "        @warn(\"error in aggregate savings. A=\",round(A,digits=4), \n",
    "            \"; ∫aᵢ′ℓ(di)=\",round(sum(stationaryDist.*ga),digits=4))\n",
    "    else\n",
    "        (!noprint)&&println(\"Passed: aggregate savings. A=\",round(A,digits=4), \n",
    "            \"; ∫aᵢ′ℓ(di)=\",round(sum(stationaryDist.*ga),digits=4))\n",
    "    end\n",
    "    if abs(A-sum(stationaryDist.*repeat(economy.aGrid,1,ny))) > na*ny*1e-10\n",
    "        @warn(\"error in convergence for savings. A=\",round(A,digits=4), \n",
    "            \"; ∫aᵢℓ(di)=\",round(sum(stationaryDist.*repeat(economy.aGrid,1,ny)),digits=4))\n",
    "    else\n",
    "        (!noprint)&&println(\"Passed: convergence for savings. A=\",round(A,digits=4), \n",
    "            \"; ∫aᵢℓ(di)=\",round(sum(stationaryDist.*repeat(economy.aGrid,1,ny)),digits=4))\n",
    "    end\n",
    "\n",
    "    Lτ = sum(stationaryDist.*((repeat(economy.ys',economy.na,1).*gl).^τ))\n",
    "    C_ = w*Lτ - A + Tt + R*sum(stationaryDist.*repeat(economy.aGrid,1,ny))\n",
    "    if !(C ≈ C_)\n",
    "        @warn(\"error in aggregate consumption. C=\",round(C,digits=4), \n",
    "            \"; ∫cᵢ ℓ(di)=\",round(C_,digits=4))\n",
    "    else\n",
    "        (!noprint)&&println(\"Passed: aggregate consumption. C=\",round(C,digits=4), \n",
    "            \"; ∫cᵢ ℓ(di)=\",round(C_,digits=4))\n",
    "    end\n",
    "    # government budget: G + Tt + (R-1)*B = Y - δ*K - (R-1)*K - w*Lτ, with B = A - K\n",
    "    gov_bc = K^α * L^(1-α) - δ*K - (R-1)*A - w*Lτ - Tt - G\n",
    "    if (abs(gov_bc) > 1e-10)\n",
    "        @warn(\"error in govt budget constraint. Gap=\",round(gov_bc,sigdigits=4))\n",
    "    else\n",
    "        (!noprint)&&println(\"Passed: govt budget constraint. Gap=\",round(gov_bc,sigdigits=4))\n",
    "    end\n",
    "    rc = K^α * L^(1-α) - δ*K - C - G\n",
    "    if (abs(rc) > 1e-10)\n",
    "        @warn(\"error in resource constraint. Gap=\",round(rc,sigdigits=4))\n",
    "    else\n",
    "        (!noprint)&&println(\"Passed: resource constraint. Gap=\",round(rc,sigdigits=4))\n",
    "    end\n",
    "        \n",
    "    return nothing\n",
    "end"
   ]
  },
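  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`check_solution` builds the matrix of productivity levels in two ways: `repeat(ys,1,na)'` in the labor-supply check and `repeat(economy.ys',economy.na,1)` in the consumption check. The following sketch uses toy values; all names are illustrative stand-ins, not the model's actual objects. It checks that both expressions give the same `na×ny` matrix, whose entry `[ia,iy]` is `ys[iy]`. It also shows that broadcasting the row vector `ys'` gives the same aggregate without allocating that matrix.\n",
    "\n",
    "```julia\n",
    "# Toy stand-ins for `na`, `ny`, `ys`, the labor policy `gl` and the stationary distribution\n",
    "na, ny = 3, 2\n",
    "ys = [0.5, 1.5]\n",
    "gl = fill(2.0, na, ny)\n",
    "λ  = fill(1 / (na * ny), na, ny)   # uniform weights summing to one\n",
    "\n",
    "Y1 = repeat(ys, 1, na)'   # ny×na matrix, transposed to na×ny\n",
    "Y2 = repeat(ys', na, 1)   # 1×ny row vector stacked na times: na×ny\n",
    "@assert Y1 == Y2          # in both, entry [ia,iy] equals ys[iy]\n",
    "\n",
    "# Broadcasting the 1×ny row vector ys' over the na×ny policy avoids building Y1 or Y2\n",
    "L_repeat = sum(λ .* (Y1 .* gl))\n",
    "L_bcast  = sum(λ .* (ys' .* gl))\n",
    "@assert L_repeat ≈ L_bcast\n",
    "```\n",
    "\n",
    "With these values both sums equal the average of `ys` times a labor supply of 2.0, i.e. approximately 2.0."
   ]
  },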
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Describing the model solution <a id=\"description-steady-state\"></a>[<font size=1>(back to menu)</font>](#steady-state)\n",
    "\n",
    "The following function computes a number of useful quantities (Gini coefficient, public-spending-to-GDP ratio, etc.) characterizing the steady state. The function is invoked by: \n",
    "> `describe_solution(solution::AiyagariSolution, economy::Economy; calib::String=\"quarterly\", compute_persistence=false)`\n",
    "\n",
    "where:\n",
    "* `solution::AiyagariSolution` is the mutable `struct` containing the steady-state allocation;\n",
    "* `economy::Economy` is an immutable `struct` `Economy` which contains the economy parameters;\n",
    "* `calib::String=\"quarterly\"` sets the frequency of the calibration. It is used to annualize GDP in the stock-to-GDP ratios (debt- and capital-to-GDP). Only two frequencies are handled: quarterly (if `calib` contains a 'q', case-insensitive) and annual (otherwise);\n",
    "* `compute_persistence=false`: when `true`, the function also reports the average persistence of the labor process and the standard deviation of its innovations (via `Lab_Process`). This requires the `DataFrames` and `RCall` packages.\n",
    "\n",
    "The function returns the description as a dictionary of type `Dict{String,T}`. The dictionary can be printed with `print_dict` from [Utils](./Utils.ipynb).\n",
    "\n",
    "***Remark.*** Due to ex-ante heterogeneity "
   ]
  },
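  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `Gini` function used in `describe_solution` comes from [Utils](./Utils.ipynb) and is not reproduced here. For reference, the following is a minimal sketch of a Gini coefficient for a discrete distribution, where values `x` (e.g. `ga`) carry probability weights `w` (e.g. `stationaryDist`). It sorts by value, builds the Lorenz curve from cumulative shares, and returns one minus twice the area under that curve (trapezoid rule). The helper name `weighted_gini` is hypothetical, the sketch assumes a positive mean, and `Gini` in Utils may be implemented differently.\n",
    "\n",
    "```julia\n",
    "# Hypothetical helper: Gini coefficient of values x with probability weights w (same shape).\n",
    "# Computes 1 - 2 × (area under the Lorenz curve), using the trapezoid rule.\n",
    "function weighted_gini(x::AbstractArray{<:Real}, w::AbstractArray{<:Real})\n",
    "    p  = sortperm(vec(x))\n",
    "    xs = vec(x)[p]\n",
    "    ws = vec(w)[p] ./ sum(w)      # normalized weights, sorted by value\n",
    "    μ  = sum(ws .* xs)             # mean value (assumed positive)\n",
    "    G  = 1.0\n",
    "    S_prev = 0.0                   # Lorenz curve at the previous mass point\n",
    "    for i in eachindex(xs)\n",
    "        S = S_prev + ws[i] * xs[i] / μ\n",
    "        G -= ws[i] * (S_prev + S)  # twice the trapezoid area of this slice\n",
    "        S_prev = S\n",
    "    end\n",
    "    return G\n",
    "end\n",
    "\n",
    "weighted_gini([1.0 2.0; 3.0 4.0], fill(0.25, 2, 2))  # ≈ 0.25\n",
    "```\n",
    "\n",
    "Two polar cases help check the sketch: equal values give 0, and `weighted_gini([0.0, 1.0], [0.5, 0.5])` gives 0.5 (half of the population holds everything)."
   ]
  },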
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "function describe_solution(solution::AiyagariSolution, economy::Economy; calib::String=\"quarterly\", compute_persistence=false)    \n",
    "    @unpack β,α,κ,δ,τk,τ,Tt,u′,l_supply,na,a_min,aGrid,ny,ys,n_ex_ante,nys = economy\n",
    "    @unpack ga,gc,gl,Rt,wt,R,w,A,C,K,L,G,Y,B,stationaryDist,residEuler = solution\n",
    "    T = typeof(β)\n",
    "    \n",
    "    # Adjusting stocks to the calibration (quarterly or yearly)\n",
    "    stock_adj = occursin('q', lowercase(calib)) ? 4one(typeof(β)) : one(typeof(β))\n",
    "    \n",
    "    # Computing post-consumption-tax quantities\n",
    "    R_init = R\n",
    "    w_init = w\n",
    "    A_init = A\n",
    "    B_init = A_init - K\n",
    "    Lτ = sum(stationaryDist.*((repeat(economy.ys',economy.na,1).*gl).^τ))\n",
    "    \n",
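    "    # Computing the average MPC: forward-difference slope of consumption along the asset grid,\n",
    "    # weighted by the stationary distribution (the top grid point has no forward difference)\n",
    "    diff_gc = gc[2:end,:] - gc[1:end-1,:]\n",
    "    diff_a  = diff(aGrid)\n",
    "    mpc = sum(diff_gc.*stationaryDist[1:end-1,:]./diff_a)\n",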
    "    \n",
    "    # Computing total fiscal revenues\n",
    "    capital_tax = τk*(Rt-1)*A_init\n",
    "    labor_tax   = wt*L - κ*wt^τ*Lτ\n",
    "    conso_tax   = 0.0\n",
    "    tot_tax     = capital_tax+labor_tax+conso_tax\n",
    "\n",
    "    # Computing credit-constrained agents\n",
    "    ind = findall(x->x<=1e-9,solution.ga[:])\n",
    "    dist = solution.stationaryDist[:]\n",
    "    share  = sum(dist[ind])\n",
    "    \n",
    "    toR = Dict{String,T}(\"01. Gini\" => Gini(ga, stationaryDist), \n",
    "                \"02. Debt-to-GDP, B/Y\"  => B_init/(stock_adj*Y), \n",
    "                \"03. Public spending-to-GDP, G/Y\"  => G/Y, \n",
    "                \"04. Aggregate consumption-to-GDP, C/Y\"  => C/Y,\n",
    "                \"05. Capital-to-GDP, K/Y\"  => K/(stock_adj*Y),\n",
    "                \"06. Investment-to-GDP, I/Y\"  => δ*K/Y,\n",
    "                \"07. Transfers-to-GDP, Tt/Y\" => Tt/Y,\n",
    "                \"08. Aggregate labor supply, L\" => L,\n",
    "                \"09. Average MPC\" => mpc,\n",
    "                \"10. Consumption tax-to-GDP\" => conso_tax/Y,\n",
    "                \"11. Labor tax-to-GDP\" => labor_tax/Y,\n",
    "                \"12. Capital tax-to-GDP\" => capital_tax/Y,\n",
    "                \"13. Total tax-to-GDP\" => tot_tax/Y,\n",
    "                \"14. Share of credit-constrained agents\" => share)\n",
    "\n",
    "    if compute_persistence\n",
    "        try         \n",
    "            rho_lab,sd_lab = Lab_Process(economy)\n",
    "            toR[\"15. Average Persistence of labor process\"] = rho_lab\n",
    "            toR[\"16. Standard Deviation of Innov.\"]         = sd_lab\n",
    "        catch err\n",
    "            println(\"Persistence could not be computed (this requires 'DataFrames' and 'RCall'): \", err)\n",
    "            println(\"It will not be reported (the rest of the computation is unaffected).\")\n",
    "        end\n",
    "    end\n",
    "    return toR\n",
    "end"
   ]
  },
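  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The average MPC in `describe_solution` is a finite-difference slope of the consumption policy `gc` with respect to the asset grid `aGrid`, weighted by the stationary distribution. The mass on the last grid point has no forward difference and is left out. The following toy reproduction of that formula uses made-up values for the grid, the policy and the distribution.\n",
    "\n",
    "```julia\n",
    "# Made-up stand-ins for `aGrid` (na = 4), `gc` (na×ny, ny = 2) and `stationaryDist`\n",
    "aGrid = [0.0, 1.0, 2.0, 4.0]\n",
    "gc    = [1.0 2.0; 1.2 2.2; 1.4 2.4; 1.6 2.6]   # slope 0.2 on the first two steps, 0.1 on the last\n",
    "λ     = fill(1 / 8, 4, 2)                      # uniform distribution over the 8 states\n",
    "\n",
    "diff_gc = gc[2:end, :] - gc[1:end-1, :]        # (na-1)×ny consumption increments\n",
    "diff_a  = diff(aGrid)                           # na-1 grid steps, broadcast across columns\n",
    "mpc = sum(diff_gc .* λ[1:end-1, :] ./ diff_a)  # only the weights of the first na-1 rows enter\n",
    "```\n",
    "\n",
    "Here the slopes 0.2, 0.2 and 0.1 each receive weight 1/8 in both columns, so `mpc` is about 0.125. Rescaling by the included mass (3/4) would give about 0.167 instead. On a fine grid with little mass at the upper bound, the difference should be small."
   ]
  },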
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Julia 1.10.5",
   "language": "julia",
   "name": "julia-1.10"
  },
  "language_info": {
   "file_extension": ".jl",
   "mimetype": "application/julia",
   "name": "julia",
   "version": "1.10.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
